/*
* Copyright 2016 The Simple File Server Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.sfs.encryption;
import com.amazonaws.auth.AWSCredentials;
import com.amazonaws.services.kms.AWSKMSClient;
import com.amazonaws.services.kms.model.DecryptRequest;
import com.amazonaws.services.kms.model.EncryptRequest;
import com.amazonaws.services.kms.model.ReEncryptRequest;
import com.google.common.base.Preconditions;
import io.vertx.core.Context;
import io.vertx.core.json.JsonObject;
import io.vertx.core.logging.Logger;
import io.vertx.core.logging.LoggerFactory;
import org.sfs.Server;
import org.sfs.SfsVertx;
import org.sfs.VertxContext;
import org.sfs.rx.Defer;
import org.sfs.rx.RxHelper;
import org.sfs.util.ConfigHelper;
import rx.Observable;
import java.nio.ByteBuffer;
import java.util.Arrays;
import java.util.Properties;
import java.util.concurrent.atomic.AtomicBoolean;
public class AwsKms implements Kms {
private static final Logger LOGGER = LoggerFactory.getLogger(AwsKms.class);
private Properties properties;
private AWSKMSClient kms;
private String keyId;
private String accessKeyId;
private String secretKey;
private AtomicBoolean started = new AtomicBoolean(false);
public AwsKms() {
}
public Observable<Void> start(VertxContext<Server> vertxContext,
JsonObject config) {
AwsKms _this = this;
SfsVertx sfsVertx = vertxContext.vertx();
Context context = sfsVertx.getOrCreateContext();
return Defer.aVoid()
.filter(aVoid -> started.compareAndSet(false, true))
.flatMap(aVoid -> {
String keyStoreAwsKmsEndpoint = ConfigHelper.getFieldOrEnv(config, "keystore.aws.kms.endpoint");
Preconditions.checkArgument(keyStoreAwsKmsEndpoint != null, "keystore.aws.kms.endpoint is required");
_this.keyId = ConfigHelper.getFieldOrEnv(config, "keystore.aws.kms.key_id");
Preconditions.checkArgument(_this.keyId != null, "keystore.aws.kms.key_id is required");
_this.accessKeyId = ConfigHelper.getFieldOrEnv(config, "keystore.aws.kms.access_key_id");
Preconditions.checkArgument(_this.accessKeyId != null, "keystore.aws.kms.access_key_id is required");
_this.secretKey = ConfigHelper.getFieldOrEnv(config, "keystore.aws.kms.secret_key");
Preconditions.checkArgument(_this.secretKey != null, "keystore.aws.kms.secret_key is required");
return RxHelper.executeBlocking(context, sfsVertx.getBackgroundPool(),
() -> {
kms = new AWSKMSClient(new AWSCredentials() {
@Override
public String getAWSAccessKeyId() {
return _this.accessKeyId;
}
@Override
public String getAWSSecretKey() {
return _this.secretKey;
}
});
kms.setEndpoint(keyStoreAwsKmsEndpoint);
return (Void) null;
});
})
.singleOrDefault(null);
}
public String getKeyId() {
return keyId;
}
@Override
public Observable<Encrypted> encrypt(VertxContext<Server> vertxContext, byte[] plainBytes) {
SfsVertx sfsVertx = vertxContext.vertx();
Context context = sfsVertx.getOrCreateContext();
return Observable.defer(() -> {
byte[] cloned = Arrays.copyOf(plainBytes, plainBytes.length);
return RxHelper.executeBlocking(context, sfsVertx.getBackgroundPool(), () -> {
try {
EncryptRequest req =
new EncryptRequest()
.withKeyId(keyId)
.withPlaintext(ByteBuffer.wrap(cloned));
ByteBuffer buffer = kms.encrypt(req).getCiphertextBlob();
byte[] b = new byte[buffer.remaining()];
buffer.get(b);
return new Encrypted(b, String.format("xppsaws:%s", keyId));
} finally {
Arrays.fill(cloned, (byte) 0);
}
});
});
}
@Override
public Observable<Encrypted> reencrypt(VertxContext<Server> vertxContext, byte[] cipherBytes) {
SfsVertx sfsVertx = vertxContext.vertx();
Context context = sfsVertx.getOrCreateContext();
return Observable.defer(() -> RxHelper.executeBlocking(context, sfsVertx.getBackgroundPool(), () -> {
ReEncryptRequest req =
new ReEncryptRequest()
.withDestinationKeyId(keyId)
.withCiphertextBlob(ByteBuffer.wrap(cipherBytes.clone()));
ByteBuffer buffer = kms.reEncrypt(req).getCiphertextBlob();
byte[] b = new byte[buffer.remaining()];
buffer.get(b);
return new Encrypted(b, keyId);
}));
}
@Override
public Observable<byte[]> decrypt(VertxContext<Server> vertxContext, byte[] cipherBytes) {
SfsVertx sfsVertx = vertxContext.vertx();
Context context = sfsVertx.getOrCreateContext();
return Observable.defer(() -> RxHelper.executeBlocking(context, sfsVertx.getBackgroundPool(), () -> {
DecryptRequest req =
new DecryptRequest()
.withCiphertextBlob(ByteBuffer.wrap(cipherBytes.clone()));
ByteBuffer buffer = kms.decrypt(req).getPlaintext();
byte[] b = new byte[buffer.remaining()];
buffer.get(b);
return b;
}));
}
public Observable<Void> stop(VertxContext<Server> vertxContext) {
SfsVertx sfsVertx = vertxContext.vertx();
Context context = sfsVertx.getOrCreateContext();
return Defer.aVoid()
.filter(aVoid -> started.compareAndSet(true, false))
.flatMap(aVoid -> {
if (properties != null) {
properties.clear();
properties = null;
}
if (kms != null) {
return RxHelper.executeBlocking(context, sfsVertx.getBackgroundPool(), () -> {
try {
kms.shutdown();
} catch (Throwable e) {
LOGGER.warn("Unhandled Exception", e);
}
return (Void) null;
});
}
return Defer.aVoid();
})
.singleOrDefault(null);
}
}